package com.yxsk.relay.job.component.endpoint.context;

import com.yxsk.relay.job.component.common.exception.net.NetHandleInterceptException;
import com.yxsk.relay.job.component.common.exception.serialization.MessageSerializeException;
import com.yxsk.relay.job.component.common.net.handler.interceptor.RpcHandleInterceptor;
import com.yxsk.relay.job.component.common.net.message.NetRequestMessage;
import com.yxsk.relay.job.component.common.net.message.NetResponseMessage;
import com.yxsk.relay.job.component.common.net.serialization.JsonMessageEntitySerializer;
import com.yxsk.relay.job.component.common.net.serialization.MessageEntitySerializer;
import com.yxsk.relay.job.component.common.protocol.message.BaseRequest;
import com.yxsk.relay.job.component.common.utils.CollectionUtils;
import com.yxsk.relay.job.component.common.utils.SerialNoUtils;
import com.yxsk.relay.job.component.endpoint.async.RelayJobAsyncWorker;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import org.jboss.netty.channel.MessageEvent;

import java.net.InetSocketAddress;
import java.util.Map;

/**
 * @Description Thread-bound access to the network request that is currently being handled
 * (URI, remote address, headers, request serial number). The context is populated by
 * {@link NetRequestContextInterceptor} on the thread handling an RPC request and propagated
 * to async work by {@link AsyncTaskWorkInterceptor}.
 * @Author 11376
 * @CreateTime 2019/9/9 8:34
 */
@Slf4j
@UtilityClass
public class RelayJobNetContextHolder {

    /** Serializer used to decode the raw request body when looking up its serial number. */
    private final MessageEntitySerializer SERIALIZER = new JsonMessageEntitySerializer();

    /**
     * Returns the URI of the current request.
     *
     * @return the request URI, or {@code null} when no request context is bound to this thread
     */
    public String getRequestUri() {
        NetRequestMessage message = getContextMessage();
        return message == null ? null : message.getUri();
    }

    /**
     * Returns the remote host of the current request.
     *
     * @return the remote IP address; the host string if the socket address is unresolved;
     * or {@code null} when there is no request context or no socket address
     */
    public String getRemoteAddress() {
        NetRequestMessage message = getContextMessage();
        if (message == null) {
            return null;
        }
        InetSocketAddress socketAddress = message.getSocketAddress();
        if (socketAddress == null) {
            return null;
        }
        // getAddress() is null for unresolved addresses; fall back to the host string
        // rather than throwing a NullPointerException.
        return socketAddress.getAddress() != null
                ? socketAddress.getAddress().getHostAddress()
                : socketAddress.getHostString();
    }

    /**
     * Returns the headers of the current request.
     *
     * @return the header map, or {@code null} when no request context is bound to this thread
     */
    public Map<String, String> getHeaders() {
        NetRequestMessage message = getContextMessage();
        return message == null ? null : message.getHeaders();
    }

    /**
     * Returns a single header of the current request.
     *
     * @param header header name to look up
     * @return the header value, or {@code null} when there are no headers or it is absent
     */
    public String getHeader(String header) {
        Map<String, String> headers = getHeaders();
        return CollectionUtils.isEmpty(headers) ? null : headers.get(header);
    }

    /**
     * Returns the serial number of the current request, for use as a log/trace id.
     *
     * <p>The raw request body is deserialized as a {@link BaseRequest}. If there is no context,
     * no body, the body cannot be deserialized, or it carries no serial number, a fresh id from
     * {@link SerialNoUtils#nextId()} is returned instead.
     *
     * @return the request serial number, or a newly generated id
     */
    public String getRequestId() {
        NetRequestMessage message = getContextMessage();
        byte[] bytes = message == null ? null : message.getMessage();
        if (bytes != null) {
            try {
                BaseRequest request = SERIALIZER.deserialize(bytes, BaseRequest.class);
                if (request != null && request.getSerialNo() != null) {
                    return request.getSerialNo();
                }
            } catch (MessageSerializeException e) {
                // Best effort: a log id must always be produced, but keep the cause visible.
                log.warn("Failed to deserialize request body, generate random log id.", e);
                return SerialNoUtils.nextId();
            }
        }
        if (log.isWarnEnabled()) {
            log.warn("Not found request context, generate random log id.");
        }
        return SerialNoUtils.nextId();
    }

    /**
     * Looks up the request bound to the current thread: the network request context first,
     * then the async task context.
     */
    private NetRequestMessage getContextMessage() {
        NetRequestMessage message = NetRequestContextInterceptor.REQUEST_CONTEXT_MAP.get();
        if (message == null) {
            message = AsyncTaskWorkInterceptor.TASK_CONTEXT_MAP.get();
        }
        return message;
    }

    /**
     * RPC interceptor that binds the incoming request to the handling thread in {@link #before}
     * and clears it in {@link #doLast}.
     */
    public class NetRequestContextInterceptor implements RpcHandleInterceptor {

        /*
         * Deliberately a plain (non-inheritable) ThreadLocal: with an InheritableThreadLocal, any
         * pooled thread created while a request is being handled would copy that request and keep
         * it after doLast() clears the parent's value, so later work on that thread would see a
         * stale request. Async propagation is done explicitly by AsyncTaskWorkInterceptor.
         */
        static final ThreadLocal<NetRequestMessage> REQUEST_CONTEXT_MAP = new ThreadLocal<>();

        @Override
        public void before(NetRequestMessage message) throws NetHandleInterceptException {
            REQUEST_CONTEXT_MAP.set(message);
        }

        @Override
        public void after(NetResponseMessage message) throws NetHandleInterceptException {
            // Nothing.
        }

        @Override
        public void doLast(MessageEvent event) {
            REQUEST_CONTEXT_MAP.remove();
        }
    }

    /**
     * Async worker interceptor that captures the caller's request context when it is constructed
     * and installs it on the worker thread for the duration of the work.
     */
    public class AsyncTaskWorkInterceptor implements RelayJobAsyncWorker.WorkInterceptor {

        // Plain ThreadLocal for the same reason as REQUEST_CONTEXT_MAP: values are propagated
        // explicitly through this interceptor, never by thread-creation inheritance.
        static final ThreadLocal<NetRequestMessage> TASK_CONTEXT_MAP = new ThreadLocal<>();

        /** Context captured on the submitting thread; may be {@code null}. */
        private final NetRequestMessage message;

        public AsyncTaskWorkInterceptor() {
            // Prefer the network request context; fall back to the async task context so that
            // nested async submissions keep the originating request.
            NetRequestMessage captured = NetRequestContextInterceptor.REQUEST_CONTEXT_MAP.get();
            if (captured == null) {
                captured = TASK_CONTEXT_MAP.get();
            }
            message = captured;
        }

        @Override
        public void beforeWork() {
            // Inject the captured context into the worker thread.
            TASK_CONTEXT_MAP.set(message);
        }

        @Override
        public void afterWork() {
            // Release the worker thread's context.
            TASK_CONTEXT_MAP.remove();
        }
    }

}
